import torch,random
import torch.nn as nn
from torch.nn import functional as F
from moco.utils import *
from braincog.base.node import *
import torch.nn.init as init
from torchvision.models.video import r2plus1d_18, R2Plus1D_18_Weights, VideoResNet
from timm.models import register_model


class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722

    The key encoder starts as a copy of the query encoder, is excluded from
    gradient updates and is moved towards the query encoder by a momentum
    rule on every forward pass. ``forward`` uses ``torch.distributed``
    collectives (batch shuffle / gather), so it needs an initialised process
    group.
    """

    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False, args=None, **kwargs):
        """
        base_encoder: factory called as ``base_encoder(num_classes=dim)``; the
            result must expose a linear ``fc`` layer when ``mlp`` is True
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        mlp: if True, put a hidden Linear + ReLU in front of each encoder's ``fc``
        args: passed to the project helper ``exchange_conv`` together with each encoder
        **kwargs: accepted for interface compatibility and ignored
        """
        super().__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        exchange_conv(self.encoder_q, args)
        exchange_conv(self.encoder_k, args)

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc
            )
            self.encoder_k.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc
            )

        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue: one unit-norm column per stored key
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        # write position of the next batch of keys inside the queue
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder:
        ``param_k <- m * param_k + (1 - m) * param_q`` for every parameter pair.
        """
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """
        Overwrite the oldest queue columns with ``keys`` gathered from all
        processes and advance the queue pointer (wrapping at ``K``).
        """
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        # the slice below never wraps, so the queue size must be a multiple
        # of the gathered batch size
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr : ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***

        Returns the shuffled samples assigned to this rank and the index
        needed by ``_batch_unshuffle_ddp`` to restore the original order.
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index, placed on the same device as the data rather
        # than on whatever the current CUDA device happens to be
        idx_shuffle = torch.randperm(batch_size_all).to(x.device)

        # broadcast to all gpus so every rank uses rank 0's permutation
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k, criterion, *args, **kwargs):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
            criterion: callable invoked as ``criterion(logits, labels)``
        Output:
            [loss, loss, zero], logits, labels
            (the loss is repeated and padded with a constant zero tensor;
            labels are all zero because the positive logit is column 0)
        """

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators, created on the logits' device
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        # loss
        loss = criterion(logits, labels)

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return [loss, loss, torch.tensor(0.0)], logits, labels

class MoCo_PredNext_clip(nn.Module):
    """
    MoCo (https://arxiv.org/abs/1911.05722) with an additional next-clip
    prediction loss.

    On top of the query encoder, momentum key encoder and negative queue, a
    predictor head (``predictor_clip``) maps the time-averaged features of
    the current clip to those of the next clip. ``forward`` uses
    ``torch.distributed`` collectives and therefore needs an initialised
    process group.
    """

    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False, pred_dim=512, args=None, **kwargs):
        """
        base_encoder: factory called as ``base_encoder(num_classes=dim)``; the
            result must expose a linear ``fc`` layer when ``mlp`` is True
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        mlp: if True, put a hidden Linear + ReLU in front of each encoder's ``fc``
        pred_dim: hidden dimension of the clip predictor (default: 512)
        args: must provide ``pred_step`` and ``pred_alpha``; also passed to
            the project helper ``exchange_conv``
        **kwargs: accepted for interface compatibility and ignored
        """
        super().__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        exchange_conv(self.encoder_q, args)
        exchange_conv(self.encoder_k, args)

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp),
                nn.ReLU(),
                self.encoder_q.fc
            )
            self.encoder_k.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp),
                nn.ReLU(),
                self.encoder_k.fc
            )

        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue: one unit-norm column per stored key
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        # NOTE: no ``predictor_next`` head is built here, so ``pred_next``
        # cannot be called on this class as it stands.

        # parameter-free normalisation applied to features before the clip loss
        self.clip_bn = nn.BatchNorm1d(dim, affine=False)
        # 2-layer predictor: current-clip features -> next-clip features
        self.predictor_clip = nn.Sequential(
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),  # hidden layer
            nn.Linear(pred_dim, dim)  # output layer
        )

        self.pred_step = args.pred_step
        self.pred_alpha = args.pred_alpha
        # only the query encoder gets this flag; presumably it makes the
        # encoder keep the time axis in its output -- TODO confirm
        self.encoder_q.sum_output = False
        self.criterion = nn.CosineSimilarity(dim=1)

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder:
        ``param_k <- m * param_k + (1 - m) * param_q`` for every parameter pair.
        """
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """
        Overwrite the oldest queue columns with ``keys`` gathered from all
        processes and advance the queue pointer (wrapping at ``K``).
        """
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        # the slice below never wraps, so the queue size must be a multiple
        # of the gathered batch size
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr : ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***

        Returns the shuffled samples assigned to this rank and the index
        needed by ``_batch_unshuffle_ddp`` to restore the original order.
        """
        # gather from all gpus; the input is a slice of a larger tensor, so
        # make it contiguous before handing it to the gather helper
        x = x.contiguous()
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index, placed on the same device as the data rather
        # than on whatever the current CUDA device happens to be
        idx_shuffle = torch.randperm(batch_size_all).to(x.device)

        # broadcast to all gpus so every rank uses rank 0's permutation
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k, criterion, epoch_ratio, *args, **kwargs):
        """
        Input:
            im_q: query clips; index 0 along dim 1 is the current clip and
                index 1 is the next clip
            im_k: key clips; only index 0 along dim 1 (current clip) is used
            criterion: callable invoked as ``criterion(logits, labels)``
            epoch_ratio: multiplies ``pred_alpha`` when mixing the losses
        Output:
            [loss, loss_full, loss_next, loss_clip], logits, labels
            (``loss_next`` is a constant zero because the step-prediction
            branch is disabled)
        """
        im_q_next_clip = im_q[:, 1]
        im_q = im_q[:, 0]
        im_k = im_k[:, 0]

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute query features
        q = self.encoder_q(im_q)  # per the original notes: [T, batch, C]
        with torch.no_grad():
            q_next_clip = self.encoder_q(im_q_next_clip)  # Shape: [T, batch, C]

        loss_next = torch.tensor(0.0)
        # Only query-side features are available with a time axis here, so
        # the query features are passed for both the "q" and "k" slots.
        # NOTE(review): nn.BatchNorm1d normalises over dim 1 of a 3-D input;
        # if the features really are [T, batch, C], dim 1 is the batch axis
        # -- confirm the encoder's output layout.
        loss_clip = self.pred_clip(self.clip_bn(q), self.clip_bn(q), self.clip_bn(q_next_clip), self.clip_bn(q_next_clip))
        loss_clip = loss_clip * 5

        # collapse the leading (time) axis before the contrastive loss
        q = q.mean(0)
        q = nn.functional.normalize(q, dim=1)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators, created on the logits' device
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        # loss
        loss_full = criterion(logits, labels)

        # blend prediction and contrastive losses; the prediction weight is
        # pred_alpha scaled by epoch_ratio
        loss = self.pred_alpha*epoch_ratio*(loss_next*0.2+loss_clip*0.8)+(1-self.pred_alpha*epoch_ratio)*loss_full
        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return [loss, loss_full, loss_next, loss_clip], logits, labels

    def pred_next(self, fq, fk):
        """
        Step-ahead prediction loss: predict the features ``pred_step`` steps
        later along dim 0 and score them with negative cosine similarity.

        NOTE: relies on ``self.predictor_next``, which ``__init__`` does not
        create, so calling this raises AttributeError unless that attribute
        is set elsewhere. ``forward`` does not call it.

        fq, fk: 3-D feature tensors indexed [time, batch, feature]
        A negative ``pred_step`` draws a random step in [1, time_steps - 1].
        """
        # Get dimensions
        time_steps, batch_size, feature_dim = fq.shape

        pred_step = self.pred_step
        if pred_step < 0:
            pred_step = random.randint(1, time_steps - 1)
        current_fq = fq[:-pred_step]  # Shape: [T-step, batch, C]
        next_fq = fq[pred_step:]      # Shape: [T-step, batch, C]
        current_fk = fk[:-pred_step]  # Shape: [T-step, batch, C]
        next_fk = fk[pred_step:]      # Shape: [T-step, batch, C]

        current_zq = current_fq.reshape(-1, feature_dim)  # Shape: [(T-step)*batch, C]
        next_zq = next_fq.reshape(-1, feature_dim)        # Shape: [(T-step)*batch, C]
        current_zk = current_fk.reshape(-1, feature_dim)  # Shape: [(T-step)*batch, C]
        next_zk = next_fk.reshape(-1, feature_dim)        # Shape: [(T-step)*batch, C]

        # Apply predictor to current features to predict next features
        pq = self.predictor_next(current_zq)  # Shape: [(T-step)*batch, C]
        pk = self.predictor_next(current_zk)  # Shape: [(T-step)*batch, C]

        # targets are constants: no gradient flows into them
        next_zq = next_zq.detach()
        next_zk = next_zk.detach()

        # compute loss
        # negative cosine similarity (SimSiam loss)
        loss_next = -0.5*(self.criterion(pq, next_zk).mean() + self.criterion(pk, next_zq).mean())
        return loss_next

    def pred_clip(self, fq, fk, fq_next_clip, fk_next_clip):
        """
        Next-clip prediction loss.

        Every input is averaged over dim 0; the averaged current-clip
        features go through ``predictor_clip`` and are compared with the
        (detached) averaged next-clip features of the same view using
        negative cosine similarity. Returns the mean over both views.
        """
        fq_current_clip = fq.mean(0)
        fk_current_clip = fk.mean(0)
        fq_next_clip = fq_next_clip.mean(0)
        fk_next_clip = fk_next_clip.mean(0)
        p_fq_current_clip = self.predictor_clip(fq_current_clip)
        p_fk_current_clip = self.predictor_clip(fk_current_clip)

        # targets are constants: no gradient flows into them
        fq_next_clip = fq_next_clip.detach()
        fk_next_clip = fk_next_clip.detach()
        loss_clip = -0.5*(self.criterion(p_fq_current_clip, fq_next_clip).mean() + self.criterion(p_fk_current_clip, fk_next_clip).mean())
        return loss_clip

class SimCLR(nn.Module):
    """
    SimCLR-style contrastive model: one shared encoder with an MLP projection
    head, trained with a symmetric InfoNCE loss between two views.
    """

    def __init__(self, base_encoder, dim=128, T=0.07, args=None, **kwargs):
        """
        base_encoder: factory called as ``base_encoder(num_classes=dim)``; the
            result must expose a linear ``fc`` layer
        dim: output embedding dimension (default: 128)
        T: softmax temperature (default: 0.07)
        args: passed to the project helper ``exchange_conv`` with the encoder
        **kwargs: accepted for interface compatibility and ignored
        """
        super().__init__()
        self.T = T

        # build the backbone, then swap its fc for a 2-layer projection head
        backbone = base_encoder(num_classes=dim)
        exchange_conv(backbone, args)
        hidden = backbone.fc.weight.shape[1]
        head_layers = [
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim),
        ]
        backbone.fc = nn.Sequential(*head_layers)
        self.encoder_q = backbone

    def forward(self, im_q, im_k, *args, **kwargs):
        """
        Input:
            im_q: first view of the batch
            im_k: second view of the batch
        Output:
            [loss, loss, zero], q->k similarity logits, labels
            (labels are the positions of each sample's positive among the
            gathered keys)
        """
        # embed both views with the shared encoder and L2-normalise each row
        q = F.normalize(self.encoder_q(im_q), dim=1)
        k = F.normalize(self.encoder_q(im_k), dim=1)

        n_local = q.shape[0]

        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            # compare against the embeddings of every process; this rank's
            # positives start at rank * n_local in the gathered tensor
            q_all = concat_all_gather(q)
            k_all = concat_all_gather(k)
            first = dist.get_rank() * n_local
        else:
            # single-process training: positives sit on the diagonal
            q_all, k_all, first = q, k, 0
        labels = torch.arange(first, first + n_local, dtype=torch.long, device=q.device)

        # temperature-scaled similarities in both directions
        q_k_sim = torch.mm(q, k_all.t()) / self.T
        k_q_sim = torch.mm(k, q_all.t()) / self.T

        # symmetric InfoNCE: average of the two cross-entropy terms
        loss = (F.cross_entropy(q_k_sim, labels) + F.cross_entropy(k_q_sim, labels)) / 2.0

        return [loss, loss, torch.tensor(0.0)], q_k_sim, labels

class SimCLR_PredNext_clip(nn.Module):
    """
    SimCLR-style contrastive model with an additional next-clip prediction
    loss computed by a small predictor head on time-averaged features.
    """

    def __init__(self, base_encoder, dim=128, T=0.07, pred_dim=512, args=None, **kwargs):
        """
        base_encoder: factory called as ``base_encoder(num_classes=dim)``; the
            result must expose a linear ``fc`` layer
        dim: output embedding dimension (default: 128)
        T: softmax temperature (default: 0.07)
        pred_dim: hidden dimension of the clip predictor (default: 512)
        args: must provide ``pred_step`` and ``pred_alpha``; also passed to
            the project helper ``exchange_conv``
        **kwargs: accepted for interface compatibility and ignored
        """
        super().__init__()

        self.T = T
        self.pred_dim = pred_dim

        # encoder whose fc is replaced by a 2-layer projection head
        self.encoder_q = base_encoder(num_classes=dim)
        exchange_conv(self.encoder_q, args)
        hidden = self.encoder_q.fc.weight.shape[1]
        projection = [
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim),
        ]
        self.encoder_q.fc = nn.Sequential(*projection)

        # NOTE: no ``predictor_next`` head is built here, so ``pred_next``
        # cannot be called on this class as it stands.

        # parameter-free normalisation applied to features before the clip loss
        self.clip_bn = nn.BatchNorm1d(dim, affine=False)
        # 2-layer predictor: current-clip features -> next-clip features
        clip_head = [
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),
            nn.Linear(pred_dim, dim),
        ]
        self.predictor_clip = nn.Sequential(*clip_head)

        self.pred_step = args.pred_step
        self.pred_alpha = args.pred_alpha
        # presumably makes the encoder keep the time axis in its output
        # -- TODO confirm
        self.encoder_q.sum_output = False
        self.criterion = nn.CosineSimilarity(dim=1)

    def forward(self, im_q, im_k, epoch_ratio, *args, **kwargs):
        """
        Input:
            im_q, im_k: two views; index 0 along dim 1 is the current clip
                and index 1 is the next clip
            epoch_ratio: accepted but not used by this class
        Output:
            [loss, loss_full, loss_next, loss_clip], q->k similarity logits, labels
            (``loss_next`` is a constant zero because the step-prediction
            branch is disabled)
        """
        q_now, q_future = im_q[:, 0], im_q[:, 1]
        k_now, k_future = im_k[:, 0], im_k[:, 1]

        # current clips are encoded with gradients, next clips without
        q = self.encoder_q(q_now)
        k = self.encoder_q(k_now)
        with torch.no_grad():
            q_next_clip = self.encoder_q(q_future)  # per the original notes: [T, batch, C]
            k_next_clip = self.encoder_q(k_future)  # per the original notes: [T, batch, C]

        loss_next = torch.tensor(0.0)

        # NOTE(review): nn.BatchNorm1d normalises over dim 1 of a 3-D input;
        # if the features really are [T, batch, C], dim 1 is the batch axis
        # -- confirm the encoder's output layout.
        bn_q = self.clip_bn(q)
        bn_k = self.clip_bn(k)
        bn_q_next = self.clip_bn(q_next_clip)
        bn_k_next = self.clip_bn(k_next_clip)
        loss_clip = self.pred_clip(bn_q, bn_k, bn_q_next, bn_k_next)

        # collapse the leading (time) axis, then L2-normalise the embeddings
        q = F.normalize(q.mean(0), dim=1)
        k = F.normalize(k.mean(0), dim=1)

        n_local = q.shape[0]

        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            # compare against the embeddings of every process; this rank's
            # positives start at rank * n_local in the gathered tensor
            q_all = concat_all_gather(q)
            k_all = concat_all_gather(k)
            first = dist.get_rank() * n_local
        else:
            # single-process training: positives sit on the diagonal
            q_all, k_all, first = q, k, 0
        labels = torch.arange(first, first + n_local, dtype=torch.long, device=q.device)

        # symmetric InfoNCE on temperature-scaled similarities
        q_k_sim = torch.mm(q, k_all.t()) / self.T
        k_q_sim = torch.mm(k, q_all.t()) / self.T
        loss_full = (F.cross_entropy(q_k_sim, labels) + F.cross_entropy(k_q_sim, labels)) / 2.0

        # blend prediction and contrastive losses with a fixed weight
        alpha = self.pred_alpha
        pred_term = loss_next * 0.2 + loss_clip * 0.8
        loss = alpha * pred_term + (1 - alpha) * loss_full

        return [loss, loss_full, loss_next, loss_clip], q_k_sim, labels

    def pred_next(self, fq, fk):
        """
        Step-ahead prediction loss: predict the features ``pred_step`` steps
        later along dim 0 and score them with negative cosine similarity.

        NOTE: relies on ``self.predictor_next``, which ``__init__`` does not
        create, so calling this raises AttributeError unless that attribute
        is set elsewhere. ``forward`` does not call it.

        fq, fk: 3-D feature tensors indexed [time, batch, feature]
        A negative ``pred_step`` draws a random step in [1, time_steps - 1].
        """
        n_steps, _, width = fq.shape

        step = random.randint(1, n_steps - 1) if self.pred_step < 0 else self.pred_step

        # inputs: all steps that have a successor `step` ahead;
        # targets: those successors, flattened to [(T-step)*batch, C]
        src_q = fq[:-step].reshape(-1, width)
        tgt_q = fq[step:].reshape(-1, width)
        src_k = fk[:-step].reshape(-1, width)
        tgt_k = fk[step:].reshape(-1, width)

        pq = self.predictor_next(src_q)
        pk = self.predictor_next(src_k)

        # targets are constants: no gradient flows into them
        tgt_q = tgt_q.detach()
        tgt_k = tgt_k.detach()

        # negative cosine similarity (SimSiam loss), crossed between views
        sim_qk = self.criterion(pq, tgt_k).mean()
        sim_kq = self.criterion(pk, tgt_q).mean()
        return -0.5 * (sim_qk + sim_kq)

    def pred_clip(self, fq, fk, fq_next_clip, fk_next_clip):
        """
        Next-clip prediction loss.

        Every input is averaged over dim 0; the averaged current-clip
        features go through ``predictor_clip`` and are compared with the
        (detached) averaged next-clip features of the same view using
        negative cosine similarity. Returns the mean over both views.
        """
        cur_q, cur_k = fq.mean(0), fk.mean(0)
        nxt_q, nxt_k = fq_next_clip.mean(0), fk_next_clip.mean(0)

        pred_q = self.predictor_clip(cur_q)
        pred_k = self.predictor_clip(cur_k)

        # targets are constants: no gradient flows into them
        sim_q = self.criterion(pred_q, nxt_q.detach()).mean()
        sim_k = self.criterion(pred_k, nxt_k.detach()).mean()
        return -0.5 * (sim_q + sim_k)

class SimSiam(nn.Module):
    """
    Build a SimSiam model with a single encoder and predictor MLP
    https://arxiv.org/abs/2011.10566
    """

    def __init__(self, base_encoder, dim=2048, pred_dim=512, args=None, **kwargs):
        """
        base_encoder: factory called with ``num_classes``, ``node_type``,
            ``threshold`` and ``tau``; the result must expose a linear ``fc``
        dim: feature dimension (default: 2048)
        pred_dim: hidden dimension of the predictor (default: 512)
        args: must provide ``node_type`` (a string that is evaluated),
            ``threshold`` and ``tau``; also passed to ``exchange_conv``
        **kwargs: accepted for interface compatibility and ignored
        """
        super().__init__()

        # create the encoder; num_classes is the output fc dimension.
        # NOTE(security): ``args.node_type`` is run through ``eval`` -- it
        # must come from a trusted configuration, never from untrusted input.
        self.encoder_q = base_encoder(num_classes=dim, node_type=eval(args.node_type), threshold=args.threshold, tau=args.tau)
        exchange_conv(self.encoder_q, args)

        # 3-layer projector replacing the encoder's fc: two hidden
        # Linear/BN/ReLU stages followed by a Linear/BN output stage
        prev_dim = self.encoder_q.fc.weight.shape[1]
        projector = []
        for _ in range(2):
            projector.append(nn.Linear(prev_dim, prev_dim, bias=False))
            projector.append(nn.BatchNorm1d(prev_dim))
            projector.append(nn.ReLU(inplace=True))
        projector.append(nn.Linear(prev_dim, dim))
        projector.append(nn.BatchNorm1d(dim, affine=False))
        self.encoder_q.fc = nn.Sequential(*projector)
        # hack: freeze the last Linear's bias, since it is followed by BN
        self.encoder_q.fc[-2].bias.requires_grad = False

        # 2-layer predictor
        predictor_layers = [
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),
            nn.Linear(pred_dim, dim),
        ]
        self.predictor = nn.Sequential(*predictor_layers)
        self.criterion = nn.CosineSimilarity(dim=1)

    def forward(self, im_q, im_k, *args, **kwargs):
        """
        Input:
            im_q: first view of images
            im_k: second view of images
        Output:
            [loss, loss, zero], None, None
            (the two ``None`` values fill the logits/labels slots that the
            contrastive models in this file return)
        """
        # projected features of both views (same encoder for both)
        z1 = self.encoder_q(im_q)  # NxC
        z2 = self.encoder_q(im_k)  # NxC

        # predictions made from each view
        p1 = self.predictor(z1)  # NxC
        p2 = self.predictor(z2)  # NxC

        # Stop-gradient on the targets: z1 and z2 already went through the
        # projector's final BN layer, but they are still detached here.
        target_1 = z1.detach()
        target_2 = z2.detach()

        # negative cosine similarity (SimSiam loss), crossed between views
        sim_12 = self.criterion(p1, target_2).mean()
        sim_21 = self.criterion(p2, target_1).mean()
        loss = -0.5 * (sim_12 + sim_21)

        return [loss, loss, torch.tensor(0.0)], None, None

class PredNext_clip(nn.Module):
    """
    SimSiam-style predictive model with one shared encoder and three
    predictor heads (no key encoder and no negative queue).

    The training loss mixes three negative-cosine-similarity terms:
    ``loss_full`` (view-to-view, on time-averaged features), ``loss_next``
    (features ``pred_step`` steps ahead within the current clip) and
    ``loss_clip`` (time-averaged features of the next clip).
    """

    def __init__(self, base_encoder, dim=2048, pred_dim=512, args=None, **kwargs):
        """
        base_encoder: factory called as ``base_encoder(num_classes=dim)``; the
            result must expose a linear ``fc`` layer
        dim: feature dimension (default: 2048)
        pred_dim: hidden dimension of every predictor (default: 512)
        args: must provide ``pred_step`` and ``pred_alpha``; also passed to
            the project helper ``exchange_conv``
        **kwargs: accepted for interface compatibility and ignored
        """
        super().__init__()
        # create the encoder
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        exchange_conv(self.encoder_q, args)
        self.pred_step = args.pred_step
        self.pred_alpha = args.pred_alpha
        # presumably makes the encoder keep the time axis in its output
        # -- TODO confirm
        self.encoder_q.sum_output = False

        # 3-layer projector replacing the encoder's fc
        prev_dim = self.encoder_q.fc.weight.shape[1]
        self.encoder_q.fc = nn.Sequential(
            nn.Linear(prev_dim, prev_dim, bias=False),
            nn.BatchNorm1d(prev_dim),
            nn.ReLU(inplace=True),  # first layer
            nn.Linear(prev_dim, prev_dim, bias=False),
            nn.BatchNorm1d(prev_dim),
            nn.ReLU(inplace=True),  # second layer
            nn.Linear(prev_dim, dim),  # output layer
            nn.BatchNorm1d(dim, affine=False)  # output layer normalization
        )
        self.encoder_q.fc[-2].bias.requires_grad = False  # hack: not use bias as it is followed by BN

        # one independent 2-layer predictor per loss term
        self.predictor_next = self._build_predictor(dim, pred_dim)
        self.predictor_clip = self._build_predictor(dim, pred_dim)
        self.predictor_full = self._build_predictor(dim, pred_dim)
        self.criterion = nn.CosineSimilarity(dim=1)

    @staticmethod
    def _build_predictor(dim, pred_dim):
        """Return a fresh Linear -> BN -> ReLU -> Linear predictor head."""
        return nn.Sequential(
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),  # hidden layer
            nn.Linear(pred_dim, dim)  # output layer
        )

    def _neg_cosine(self, pred_a, target_a, pred_b, target_b):
        """Return -0.5 * (mean cos(pred_a, target_a) + mean cos(pred_b, target_b))."""
        return -0.5 * (self.criterion(pred_a, target_a).mean() + self.criterion(pred_b, target_b).mean())

    def forward(self, im_q, im_k, epoch_ratio, *args, **kwargs):
        """
        Input:
            im_q, im_k: two views as 6-D tensors; index 0 along dim 1 is the
                current clip and index 1 is the next clip
            epoch_ratio: accepted but not used by this class
        Output:
            [loss, loss_full, loss_next, loss_clip], None, None
        Raises:
            ValueError: if ``im_q`` is not 6-D, or if the prediction step is
                not in ``[1, time_steps - 1]``
        """
        if len(im_q.shape) != 6:
            raise ValueError(
                f"im_q must be 6-D (batch, clip, time, channel, height, width), got shape {tuple(im_q.shape)}"
            )
        im_q_current_clip = im_q[:, 0]
        im_k_current_clip = im_k[:, 0]
        im_q_next_clip = im_q[:, 1]
        im_k_next_clip = im_k[:, 1]

        # All four passes go through the same encoder (the original note
        # reads "share BN this way"); next clips are targets only, so they
        # are encoded without gradients.
        fq = self.encoder_q(im_q_current_clip)  # Shape: [T, batch, C]
        fk = self.encoder_q(im_k_current_clip)  # Shape: [T, batch, C]
        with torch.no_grad():
            fq_next_clip = self.encoder_q(im_q_next_clip)  # Shape: [T, batch, C]
            fk_next_clip = self.encoder_q(im_k_next_clip)  # Shape: [T, batch, C]

        # Resolve and validate the prediction step up front: a step of 0 or
        # >= T would leave empty or mismatched slices and produce a NaN loss
        # or an obscure downstream error.
        time_steps, _, feature_dim = fq.shape
        pred_step = self.pred_step
        if pred_step < 0:
            if time_steps < 2:
                raise ValueError(
                    f"need at least 2 time steps to sample a prediction step, got {time_steps}"
                )
            pred_step = random.randint(1, time_steps - 1)
        if not 1 <= pred_step < time_steps:
            raise ValueError(
                f"pred_step must be in [1, {time_steps - 1}] for {time_steps} time steps, got {pred_step}"
            )

        # ---- next-clip loss: averaged current clip -> averaged next clip
        p_fq_current_clip = self.predictor_clip(fq.mean(0))
        p_fk_current_clip = self.predictor_clip(fk.mean(0))
        loss_clip = self._neg_cosine(
            p_fq_current_clip, fq_next_clip.mean(0).detach(),
            p_fk_current_clip, fk_next_clip.mean(0).detach(),
        )

        # ---- step-ahead loss inside the current clip, crossed between views
        current_zq = fq[:-pred_step].reshape(-1, feature_dim)  # Shape: [(T-step)*batch, C]
        next_zq = fq[pred_step:].reshape(-1, feature_dim)      # Shape: [(T-step)*batch, C]
        current_zk = fk[:-pred_step].reshape(-1, feature_dim)  # Shape: [(T-step)*batch, C]
        next_zk = fk[pred_step:].reshape(-1, feature_dim)      # Shape: [(T-step)*batch, C]

        # Apply predictor to current features to predict next features
        pq = self.predictor_next(current_zq)
        pk = self.predictor_next(current_zk)
        # negative cosine similarity (SimSiam loss) against detached targets
        loss_next = self._neg_cosine(pq, next_zk.detach(), pk, next_zq.detach())

        # ---- view-to-view SimSiam loss on time-averaged features
        fq = fq.mean(0)
        fk = fk.mean(0)
        p_fq = self.predictor_full(fq)
        p_fk = self.predictor_full(fk)
        loss_full = self._neg_cosine(p_fq, fk.detach(), p_fk, fq.detach())

        # blend prediction and view-to-view losses with a fixed weight
        loss = self.pred_alpha*(loss_next*0.2+loss_clip*0.8)+(1-self.pred_alpha)*loss_full
        return [loss, loss_full, loss_next, loss_clip], None, None

class BYOL(nn.Module):
    """
    BYOL model: an online encoder, a momentum-updated target encoder, and a
    predictor that is applied to the online branch only.
    https://arxiv.org/abs/2006.07733
    """

    def __init__(self, base_encoder, dim=2048, pred_dim=512, m=0.996, args=None, **kwargs):
        """
        base_encoder: callable invoked as base_encoder(num_classes=dim); the
            returned module must expose an `fc` layer with a 2-D `weight`
        dim: projector output dimension (default: 2048)
        pred_dim: hidden dimension of the predictor (default: 512)
        m: momentum coefficient for the target network update (default: 0.996)
        args: passed unchanged to exchange_conv for both encoders
        kwargs: accepted and ignored
        """
        super().__init__()

        self.m = m

        # Online (q) and target (k) backbones; num_classes is the fc output width.
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)
        for backbone in (self.encoder_q, self.encoder_k):
            exchange_conv(backbone, args)

        # Swap each backbone's fc for its own projector, online first so the
        # layers are created in the same order for both branches.
        prev_dim = self.encoder_q.fc.weight.shape[1]
        self.encoder_q.fc = self._make_projector(prev_dim, dim)
        self.encoder_k.fc = self._make_projector(prev_dim, dim)

        # Two-layer predictor used on the online branch.
        self.predictor = nn.Sequential(
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),
            nn.Linear(pred_dim, dim),
        )

        # The target starts as an exact copy of the online encoder and is
        # excluded from gradient-based optimisation.
        for src, dst in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            dst.data.copy_(src.data)
            dst.requires_grad = False

        self.criterion = nn.CosineSimilarity(dim=1)

    @staticmethod
    def _make_projector(in_dim, out_dim):
        """Build the projector: Linear-BN-ReLU-Linear-BN(affine=False)."""
        head = nn.Sequential(
            nn.Linear(in_dim, in_dim, bias=False),
            nn.BatchNorm1d(in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim, affine=False),
        )
        # hack: the output Linear is followed by BN, so its bias is frozen
        head[-2].bias.requires_grad = False
        return head

    @torch.no_grad()
    def _momentum_update_target_encoder(self):
        """
        Blend every target parameter towards its online counterpart:
        target = m * target + (1 - m) * online
        """
        keep = self.m
        paired = zip(self.encoder_q.parameters(), self.encoder_k.parameters())
        for online_p, target_p in paired:
            target_p.data = target_p.data * keep + online_p.data * (1.0 - keep)

    def forward(self, im_q, im_k, *args, **kwargs):
        """
        Input:
            im_q: first view of the batch
            im_k: second view of the batch
            (extra positional / keyword arguments are ignored)
        Output:
            ([loss, loss, 0-valued tensor], None, None)
            the two None entries stand in for logits and labels
        """
        # Online branch: encode both views, then run the predictor on each.
        z_q = self.encoder_q(im_q)
        z_k = self.encoder_q(im_k)
        p_q = self.predictor(z_q)
        p_k = self.predictor(z_k)

        # The target is refreshed before it encodes the current batch.
        self._momentum_update_target_encoder()

        # Target branch carries no gradient.
        with torch.no_grad():
            t_q = self.encoder_k(im_q)
            t_k = self.encoder_k(im_k)
        t_q, t_k = t_q.detach(), t_k.detach()

        # Symmetric negative cosine similarity: each online prediction is
        # scored against the target feature of the other view.
        loss_qk = -self.criterion(p_q, t_k).mean()
        loss_kq = -self.criterion(p_k, t_q).mean()
        loss = 0.5 * (loss_qk + loss_kq)

        return [loss, loss, torch.tensor(0.0)], None, None

class BYOL_PredNext_clip(nn.Module):
    """
    BYOL (online encoder, momentum target encoder, predictor) combined with an
    auxiliary loss that predicts the features of the next clip from the
    features of the current clip.
    https://arxiv.org/abs/2006.07733
    """

    def __init__(self, base_encoder, dim=2048, pred_dim=512, m=0.996, args=None, **kwargs):
        """
        base_encoder: callable invoked as base_encoder(num_classes=dim); the
            returned module must expose an `fc` layer with a 2-D `weight`
        dim: projector output dimension (default: 2048)
        pred_dim: hidden dimension of the predictors (default: 512)
        m: momentum coefficient for the target network update (default: 0.996)
        args: passed to exchange_conv; `args.pred_step` and `args.pred_alpha`
            are read here, so the default of None cannot actually be used
        kwargs: accepted and ignored
        """
        super().__init__()

        self.m = m

        # Online (q) and target (k) backbones; num_classes is the fc output width.
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)
        for backbone in (self.encoder_q, self.encoder_k):
            exchange_conv(backbone, args)

        # Swap each backbone's fc for its own projector, online first.
        prev_dim = self.encoder_q.fc.weight.shape[1]
        self.encoder_q.fc = self._make_projector(prev_dim, dim)
        self.encoder_k.fc = self._make_projector(prev_dim, dim)

        # Two structurally identical predictors: one for the BYOL loss and one
        # for the next-clip loss.
        self.predictor = self._make_predictor(dim, pred_dim)
        self.predictor_clip = self._make_predictor(dim, pred_dim)

        # The target starts as an exact copy of the online encoder and is
        # excluded from gradient-based optimisation.
        for src, dst in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            dst.data.copy_(src.data)
            dst.requires_grad = False

        # NOTE: no `predictor_next` module is built here (its construction was
        # disabled), so `pred_next` cannot run unless one is attached.

        self.pred_step = args.pred_step
        self.pred_alpha = args.pred_alpha
        # Only the online encoder gets this flag; presumably it makes the
        # encoder return per-time-step features -- confirm against the encoder.
        self.encoder_q.sum_output = False

        self.criterion = nn.CosineSimilarity(dim=1)

    @staticmethod
    def _make_projector(in_dim, out_dim):
        """Build the projector: Linear-BN-ReLU-Linear-BN(affine=False)."""
        head = nn.Sequential(
            nn.Linear(in_dim, in_dim, bias=False),
            nn.BatchNorm1d(in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim, affine=False),
        )
        # hack: the output Linear is followed by BN, so its bias is frozen
        head[-2].bias.requires_grad = False
        return head

    @staticmethod
    def _make_predictor(feat_dim, hidden_dim):
        """Build a predictor: Linear-BN-ReLU-Linear back to feat_dim."""
        return nn.Sequential(
            nn.Linear(feat_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, feat_dim),
        )

    @torch.no_grad()
    def _momentum_update_target_encoder(self):
        """
        Blend every target parameter towards its online counterpart:
        target = m * target + (1 - m) * online
        """
        keep = self.m
        paired = zip(self.encoder_q.parameters(), self.encoder_k.parameters())
        for online_p, target_p in paired:
            target_p.data = target_p.data * keep + online_p.data * (1.0 - keep)

    def forward(self, im_q, im_k, epoch_ratio, *args, **kwargs):
        """
        Input:
            im_q: first view; index 0 along dim 1 is the current clip and
                index 1 is the next clip
            im_k: second view, same layout as im_q
            epoch_ratio: accepted but not used in this method
        Output:
            ([loss, loss_full, loss_next, loss_clip], None, None)
            the two None entries stand in for logits and labels
        """
        q_future, k_future = im_q[:, 1], im_k[:, 1]
        im_q, im_k = im_q[:, 0], im_k[:, 0]

        # Online features of the current clips (gradients flow through these).
        feat_q = self.encoder_q(im_q)
        feat_k = self.encoder_q(im_k)
        # Online features of the next clips serve only as targets.
        # assumes encoder_q returns [T, batch, C] here -- TODO confirm
        with torch.no_grad():
            feat_q_future = self.encoder_q(q_future)
            feat_k_future = self.encoder_q(k_future)

        # The step-wise prediction loss is disabled and reported as zero.
        loss_next = torch.tensor(0.0)
        loss_clip = self.pred_clip(feat_q, feat_k, feat_q_future, feat_k_future)

        # Average over dim 0 before the BYOL predictor.
        pooled_q = feat_q.mean(0)
        pooled_k = feat_k.mean(0)
        pred_q = self.predictor(pooled_q)
        pred_k = self.predictor(pooled_k)

        # The target is refreshed before it encodes the current batch.
        self._momentum_update_target_encoder()

        with torch.no_grad():
            target_q = self.encoder_k(im_q)
            target_k = self.encoder_k(im_k)
        target_q, target_k = target_q.detach(), target_k.detach()

        # Symmetric negative cosine similarity across the two views.
        loss_qk = -self.criterion(pred_q, target_k).mean()
        loss_kq = -self.criterion(pred_k, target_q).mean()
        loss_full = 0.5 * (loss_qk + loss_kq)

        # pred_alpha balances the prediction losses against the BYOL loss.
        prediction_term = loss_next * 0.2 + loss_clip * 0.8
        loss = self.pred_alpha * prediction_term + (1 - self.pred_alpha) * loss_full
        return [loss, loss_full, loss_next, loss_clip], None, None

    def pred_next(self, fq, fk):
        """
        Predict features `pred_step` positions ahead along dim 0 and score the
        prediction of one view against the (detached) future of the other.

        fq, fk: 3-D tensors; dim 0 is sliced as the step axis and the last
            dim is kept as the feature axis
        A negative self.pred_step selects a random offset in [1, T-1].
        Requires `self.predictor_next`, which __init__ does not create.
        """
        num_steps, _, width = fq.shape

        shift = self.pred_step
        if shift < 0:
            shift = random.randint(1, num_steps - 1)

        # Same order as the slices: q-now, q-later, k-now, k-later.
        now_q, later_q, now_k, later_k = [
            chunk.reshape(-1, width)
            for chunk in (fq[:-shift], fq[shift:], fk[:-shift], fk[shift:])
        ]

        guess_q = self.predictor_next(now_q)
        guess_k = self.predictor_next(now_k)

        later_q = later_q.detach()
        later_k = later_k.detach()

        # negative cosine similarity (SimSiam-style), crossed between views
        cross_qk = self.criterion(guess_q, later_k).mean()
        cross_kq = self.criterion(guess_k, later_q).mean()
        return -0.5 * (cross_qk + cross_kq)

    def pred_clip(self, fq, fk, fq_next_clip, fk_next_clip):
        """
        Next-clip loss: average every input over dim 0, predict from the
        current clip with `predictor_clip`, and return the negative mean
        cosine similarity to the detached next-clip feature of the same view.
        """
        now_q = fq.mean(0)
        now_k = fk.mean(0)
        target_q = fq_next_clip.mean(0)
        target_k = fk_next_clip.mean(0)

        guess_q = self.predictor_clip(now_q)
        guess_k = self.predictor_clip(now_k)

        target_q = target_q.detach()
        target_k = target_k.detach()

        sim_q = self.criterion(guess_q, target_q).mean()
        sim_k = self.criterion(guess_k, target_k).mean()
        return -0.5 * (sim_q + sim_k)

class BarlowTwins(nn.Module):
    """
    Barlow Twins objective on top of a single shared encoder: the
    cross-correlation matrix between the embeddings of two views is pushed
    towards the identity matrix.
    """

    def __init__(self, base_encoder, dim=512, hidden_dim=2048, lambd=0.0051, args=None, **kwargs):
        """
        base_encoder: callable invoked as base_encoder(num_classes=dim); the
            returned module must expose an `fc` layer with a 2-D `weight`
        dim: output dimension of the projection head (default: 512)
        hidden_dim: hidden width of the projection head (default: 2048)
        lambd: weight of the redundancy-reduction (off-diagonal) term
        args: passed unchanged to exchange_conv
        kwargs: accepted and ignored
        """
        super().__init__()

        # Balances the invariance term against the redundancy-reduction term.
        self.lambd = lambd

        # One encoder processes both views.
        self.encoder_q = base_encoder(num_classes=dim)
        exchange_conv(self.encoder_q, args)
        in_features = self.encoder_q.fc.weight.shape[1]

        # Barlow Twins uses a larger projection head in place of the fc layer.
        projection = [
            nn.Linear(in_features, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, dim),
        ]
        self.encoder_q.fc = nn.Sequential(*projection)

    @staticmethod
    def _standardize(z):
        """
        Zero-mean / unit-std each feature along the batch dimension.

        NOTE(review): torch.std defaults to the unbiased estimator while the
        correlation below divides by N, so identical views give a diagonal of
        about (N-1)/N rather than exactly 1 -- confirm this is intended.
        """
        return (z - z.mean(0)) / (z.std(0) + 1e-5)

    def forward(self, im_q, im_k, *args, **kwargs):
        """
        Input:
            im_q: first view of the batch
            im_k: second view of the batch
            (extra positional / keyword arguments are ignored)
        Output:
            ([loss, loss, 0-valued tensor], None, None)
        """
        # Embeddings of both views; the matrix product below needs them 2-D.
        emb_a = self.encoder_q(im_q)
        emb_b = self.encoder_q(im_k)

        # The helper returns its input unchanged when no process group is
        # initialised, which covers the single-GPU case.
        all_a = self.concat_all_gather_BarlowTwins(emb_a)
        all_b = self.concat_all_gather_BarlowTwins(emb_b)

        norm_a = self._standardize(all_a)
        norm_b = self._standardize(all_b)

        # Total number of samples, across all processes when gathered.
        num_samples = all_a.shape[0]

        # Cross-correlation matrix, DxD.
        corr = torch.mm(norm_a.T, norm_b) / num_samples

        # Invariance term: diagonal entries should be close to 1.
        on_diag = (torch.diagonal(corr) - 1).pow(2).sum()
        # Redundancy-reduction term: off-diagonal entries should be close to 0.
        off_diag = self.off_diagonal(corr).pow(2).sum()

        loss = on_diag + self.lambd * off_diag

        # Same return format as the other methods in this file.
        return [loss, loss, torch.tensor(0.0)], None, None

    def off_diagonal(self, x):
        """
        Return the off-diagonal elements of a square matrix as a flat tensor.
        """
        rows, cols = x.shape
        assert rows == cols
        # Dropping the last element and viewing as (n-1, n+1) places every
        # diagonal entry in column 0, which is then sliced away.
        shifted = x.flatten()[:-1].view(rows - 1, rows + 1)
        return shifted[:, 1:].flatten()

    def concat_all_gather_BarlowTwins(self, tensor):
        """
        Gather `tensor` from every process and concatenate along dim 0.
        The slot of the current rank is replaced by the original tensor so
        that gradients still flow through the local shard.
        """
        dist = torch.distributed
        if not dist.is_available() or not dist.is_initialized():
            return tensor

        shards = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(shards, tensor, async_op=False)

        # Keep the gradient for this process's own tensor.
        shards[dist.get_rank()] = tensor
        return torch.cat(shards, dim=0)

class BarlowTwins_PredNext_clip(nn.Module):
    """
    Barlow Twins objective combined with an auxiliary loss that predicts the
    features of the next clip from the features of the current clip.
    """

    def __init__(self, base_encoder, dim=512, hidden_dim=2048, lambd=0.0051, pred_dim=512, args=None, **kwargs):
        """
        base_encoder: callable invoked as base_encoder(num_classes=dim); the
            returned module must expose an `fc` layer with a 2-D `weight`
        dim: output dimension of the projection head (default: 512)
        hidden_dim: hidden width of the projection head (default: 2048)
        lambd: weight of the redundancy-reduction (off-diagonal) term
        pred_dim: hidden width of the clip predictor (default: 512)
        args: passed to exchange_conv; `args.pred_step` and `args.pred_alpha`
            are read here, so the default of None cannot actually be used
        kwargs: accepted and ignored
        """
        super().__init__()

        # Balances the invariance term against the redundancy-reduction term.
        self.lambd = lambd

        # One encoder processes both views.
        self.encoder_q = base_encoder(num_classes=dim)
        exchange_conv(self.encoder_q, args)
        in_features = self.encoder_q.fc.weight.shape[1]

        # Barlow Twins uses a larger projection head in place of the fc layer.
        projection = [
            nn.Linear(in_features, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, dim),
        ]
        self.encoder_q.fc = nn.Sequential(*projection)

        # NOTE: no `predictor_next` module is built here (its construction was
        # disabled), so `pred_next` cannot run unless one is attached.

        # Parameter-free normalisation applied to features before the clip loss.
        self.clip_bn = nn.BatchNorm1d(dim, affine=False)
        # Two-layer predictor for the next-clip loss.
        self.predictor_clip = nn.Sequential(
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),
            nn.Linear(pred_dim, dim),
        )

        self.pred_step = args.pred_step
        self.pred_alpha = args.pred_alpha
        # Presumably makes the encoder return per-time-step features --
        # confirm against the encoder implementation.
        self.encoder_q.sum_output = False
        self.criterion = nn.CosineSimilarity(dim=1)

    @staticmethod
    def _standardize(z):
        """
        Zero-mean / unit-std each feature along the batch dimension.

        NOTE(review): torch.std defaults to the unbiased estimator while the
        correlation below divides by N, so identical views give a diagonal of
        about (N-1)/N rather than exactly 1 -- confirm this is intended.
        """
        return (z - z.mean(0)) / (z.std(0) + 1e-5)

    def forward(self, im_q, im_k, epoch_ratio, *args, **kwargs):
        """
        Input:
            im_q: first view; index 0 along dim 1 is the current clip and
                index 1 is the next clip
            im_k: second view, same layout as im_q
            epoch_ratio: accepted but not used in this method
        Output:
            ([loss, loss_full, loss_next, loss_clip], None, None)
        """
        q_future, k_future = im_q[:, 1], im_k[:, 1]
        im_q, im_k = im_q[:, 0], im_k[:, 0]

        # Features of the current clips (gradients flow through these).
        emb_a = self.encoder_q(im_q)
        emb_b = self.encoder_q(im_k)
        # Features of the next clips serve only as prediction targets.
        # assumes encoder_q returns [T, batch, C] here -- TODO confirm
        with torch.no_grad():
            emb_a_future = self.encoder_q(q_future)
            emb_b_future = self.encoder_q(k_future)

        # The step-wise prediction loss is disabled and reported as zero.
        loss_next = torch.tensor(0.0)

        # NOTE(review): for 3-D input BatchNorm1d treats dim 1 as the channel
        # axis; if the features really are [T, batch, C], dim 1 is the batch
        # axis -- confirm the intended layout.
        bn_a = self.clip_bn(emb_a)
        bn_b = self.clip_bn(emb_b)
        bn_a_future = self.clip_bn(emb_a_future)
        bn_b_future = self.clip_bn(emb_b_future)
        # Fixed factor of 60 applied to the clip loss.
        loss_clip = self.pred_clip(bn_a, bn_b, bn_a_future, bn_b_future) * 60

        # Average over dim 0 before the Barlow Twins loss.
        emb_a = emb_a.mean(0)
        emb_b = emb_b.mean(0)

        # The helper returns its input unchanged when no process group is
        # initialised, which covers the single-GPU case.
        all_a = self.concat_all_gather_BarlowTwins(emb_a)
        all_b = self.concat_all_gather_BarlowTwins(emb_b)

        norm_a = self._standardize(all_a)
        norm_b = self._standardize(all_b)

        # Total number of samples, across all processes when gathered.
        num_samples = all_a.shape[0]

        # Cross-correlation matrix, DxD.
        corr = torch.mm(norm_a.T, norm_b) / num_samples

        # Invariance term: diagonal entries should be close to 1.
        on_diag = (torch.diagonal(corr) - 1).pow(2).sum()
        # Redundancy-reduction term: off-diagonal entries should be close to 0.
        off_diag = self.off_diagonal(corr).pow(2).sum()

        loss_full = on_diag + self.lambd * off_diag

        # pred_alpha balances the prediction losses against the Barlow loss.
        prediction_term = loss_next * 0.2 + loss_clip * 0.8
        loss = self.pred_alpha * prediction_term + (1 - self.pred_alpha) * loss_full

        # Same return format as the other methods in this file.
        return [loss, loss_full, loss_next, loss_clip], None, None

    def off_diagonal(self, x):
        """
        Return the off-diagonal elements of a square matrix as a flat tensor.
        """
        rows, cols = x.shape
        assert rows == cols
        # Dropping the last element and viewing as (n-1, n+1) places every
        # diagonal entry in column 0, which is then sliced away.
        shifted = x.flatten()[:-1].view(rows - 1, rows + 1)
        return shifted[:, 1:].flatten()

    def concat_all_gather_BarlowTwins(self, tensor):
        """
        Gather `tensor` from every process and concatenate along dim 0.
        The slot of the current rank is replaced by the original tensor so
        that gradients still flow through the local shard.
        """
        dist = torch.distributed
        if not dist.is_available() or not dist.is_initialized():
            return tensor

        shards = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(shards, tensor, async_op=False)

        # Keep the gradient for this process's own tensor.
        shards[dist.get_rank()] = tensor
        return torch.cat(shards, dim=0)

    def pred_next(self, fq, fk):
        """
        Predict features `pred_step` positions ahead along dim 0 and score the
        prediction of one view against the (detached) future of the other.

        fq, fk: 3-D tensors; dim 0 is sliced as the step axis and the last
            dim is kept as the feature axis
        A negative self.pred_step selects a random offset in [1, T-1].
        Requires `self.predictor_next`, which __init__ does not create.
        """
        num_steps, _, width = fq.shape

        shift = self.pred_step
        if shift < 0:
            shift = random.randint(1, num_steps - 1)

        # Same order as the slices: q-now, q-later, k-now, k-later.
        now_q, later_q, now_k, later_k = [
            chunk.reshape(-1, width)
            for chunk in (fq[:-shift], fq[shift:], fk[:-shift], fk[shift:])
        ]

        guess_q = self.predictor_next(now_q)
        guess_k = self.predictor_next(now_k)

        later_q = later_q.detach()
        later_k = later_k.detach()

        # negative cosine similarity (SimSiam-style), crossed between views
        cross_qk = self.criterion(guess_q, later_k).mean()
        cross_kq = self.criterion(guess_k, later_q).mean()
        return -0.5 * (cross_qk + cross_kq)

    def pred_clip(self, fq, fk, fq_next_clip, fk_next_clip):
        """
        Next-clip loss: average every input over dim 0, predict from the
        current clip with `predictor_clip`, and return the negative mean
        cosine similarity to the detached next-clip feature of the same view.
        """
        now_q = fq.mean(0)
        now_k = fk.mean(0)
        target_q = fq_next_clip.mean(0)
        target_k = fk_next_clip.mean(0)

        guess_q = self.predictor_clip(now_q)
        guess_k = self.predictor_clip(now_k)

        target_q = target_q.detach()
        target_k = target_k.detach()

        sim_q = self.criterion(guess_q, target_q).mean()
        sim_k = self.criterion(guess_k, target_k).mean()
        return -0.5 * (sim_q + sim_k)